import os
import argparse
import torch
import json
import numpy as np
from collections import defaultdict, OrderedDict
from image_synthesis.utils.misc import instantiate_from_config
from image_synthesis.modeling.modules.clip.simple_tokenizer import SimpleTokenizer
from image_synthesis.utils.io import save_config_to_yaml

def _gcc_dataset_config(data_root, split, image_tsv_files, text_tsv_files):
    """Build the TSVImageTextDataset config dict for one Conceptual Captions split."""
    return {
        "target": "image_synthesis.data.tsv_dataset.TSVImageTextDataset",
        "params": {
            'data_root': data_root,
            'name': 'conceptualcaption/{}'.format(split),
            "image_tsv_file": image_tsv_files,
            "text_tsv_file": text_tsv_files,
            "text_format": "json",
            "im_preprocessor_config": {
                "target": "image_synthesis.data.utils.image_preprocessor.SimplePreprocessor",
                "params": {
                    "size": 256,
                },
            },
        },
    }


def _count_caption_words(dataset, word_freq, batch_size=8, num_workers=8):
    """Add one count to word_freq per distinct lower-cased word of every caption in dataset."""
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=False,
    )
    for i, data in enumerate(dataloader):
        seen = i * batch_size
        if seen % 100 == 0:
            print('{}/{}...'.format(seen, len(dataset)))
        # assumes data['text'] is an iterable of caption strings -- TODO confirm
        for cap in data['text']:
            # set(): a word repeated inside one caption is counted only once,
            # so the result is the number of captions containing the word.
            for word in set(cap.lower().split()):
                word_freq[word] += 1


def conceptual_caption(data_root='', save_dir='data/captions'):
    """Count caption word frequencies over Conceptual Captions and save them.

    The validation and training splits are instantiated from their configs
    and iterated once. Every caption is lower-cased and split on whitespace;
    each distinct word of a caption adds one to that word's count. The counts
    are ordered from most to least frequent (words with equal counts keep the
    order in which they were first seen) and written with
    ``save_config_to_yaml`` to ``<save_dir>/conceptual_caption.yaml``.

    Args:
        data_root: Passed as ``data_root`` to both dataset configs.
        save_dir: Directory for the output file; created if it is missing.
    """
    val_config = _gcc_dataset_config(
        data_root,
        'val',
        ['gcc-val-image.tsv'],
        ['gcc-val-text.tsv'],
    )
    train_config = _gcc_dataset_config(
        data_root,
        'train',
        ['gcc-train-image-00.tsv', 'gcc-train-image-01.tsv'],
        ['gcc-train-text-00.tsv', 'gcc-train-text-01.tsv'],
    )

    val_dataset = instantiate_from_config(val_config)
    train_dataset = instantiate_from_config(train_config)

    word_freq = defaultdict(int)
    for dataset in (val_dataset, train_dataset):
        _count_caption_words(dataset, word_freq)

    # Most frequent first. sorted() is stable, so the order of equally
    # frequent words is deterministic (first-seen order).
    ranked = sorted(word_freq.items(), key=lambda item: item[1], reverse=True)
    word_freq = OrderedDict(ranked)

    save_file = os.path.join(save_dir, 'conceptual_caption.yaml')
    os.makedirs(save_dir, exist_ok=True)
    save_config_to_yaml(word_freq, path=save_file)
    # Report only after the write has actually happened.
    print('saved in {}'.format(save_file))

def get_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the process command line, which is
            the behavior of the original zero-argument call.

    Returns:
        argparse.Namespace with the attributes ``data_root`` and ``save_dir``.
    """
    parser = argparse.ArgumentParser(
        description='Count word frequencies of Conceptual Captions captions')
    parser.add_argument('--data_root', type=str, default='data',
                        help='dir of datasets')
    parser.add_argument('--save_dir', type=str, default='data/world_frequency',
                        help='dir to save the word frequency file')

    return parser.parse_args(argv)


if __name__ == '__main__':
    # Script entry point: read the CLI options and build the word-frequency
    # file for the dataset found under the given data root.
    args = get_args()
    conceptual_caption(
        data_root=args.data_root,
        save_dir=args.save_dir,
    )